"""
The code for GridSearchReduction wraps the source class
fairlearn.reductions.GridSearch
available in the https://github.com/fairlearn/fairlearn library
licensed under the MIT License, Copyright Microsoft Corporation
"""
from logging import warning

import fairlearn.reductions as red
import pandas as pd

from aif360.algorithms import Transformer
from aif360.sklearn.inprocessing import GridSearchReduction as skGridSearchRed


class GridSearchReduction(Transformer):
    """Grid search reduction for fair classification or regression.

    Grid search is an in-processing technique that can be used for fair
    classification or fair regression. For classification it reduces fair
    classification to a sequence of cost-sensitive classification problems,
    returning the deterministic classifier with the lowest empirical error
    subject to fair classification constraints [#agarwal18]_ among the
    candidates searched. For regression it uses the same principle to return a
    deterministic regressor with the lowest empirical error subject to the
    constraint of bounded group loss [#agarwal19]_.

    This class is a thin adapter: it converts a dataset object into a
    :class:`pandas.DataFrame` of features plus a label array and delegates
    the actual work to the scikit-learn compatible implementation stored in
    ``self.model``.

    References:
        .. [#agarwal18] `A. Agarwal, A. Beygelzimer, M. Dudik, J. Langford, and
           H. Wallach, "A Reductions Approach to Fair Classification,"
           International Conference on Machine Learning, 2018.
           <https://arxiv.org/abs/1803.02453>`_
        .. [#agarwal19] `A. Agarwal, M. Dudik, and Z. Wu, "Fair Regression:
           Quantitative Definitions and Reduction-based Algorithms,"
           International Conference on Machine Learning, 2019.
           <https://arxiv.org/abs/1905.12843>`_
    """
    def __init__(self,
                 estimator,
                 constraints,
                 prot_attr=None,
                 constraint_weight=0.5,
                 grid_size=10,
                 grid_limit=2.0,
                 grid=None,
                 drop_prot_attr=True,
                 loss="ZeroOne",
                 min_val=None,
                 max_val=None):
        """
        Args:
            estimator: An estimator implementing methods ``fit(X, y,
                sample_weight)`` and ``predict(X)``, where ``X`` is the matrix
                of features, ``y`` is the vector of labels, and
                ``sample_weight`` is a vector of weights; labels ``y`` and
                predictions returned by ``predict(X)`` are either 0 or 1 -- e.g.
                scikit-learn classifiers/regressors.
            constraints (str or fairlearn.reductions.Moment): If string, keyword
                denoting the :class:`fairlearn.reductions.Moment` object
                defining the disparity constraints -- e.g., "DemographicParity"
                or "EqualizedOdds". For a full list of possible options see
                `self.model.moments`. Otherwise, provide the desired
                :class:`~fairlearn.reductions.Moment` object defining the
                disparity constraints.
            prot_attr: String or array-like column indices or column names
                of protected attributes. If omitted (or empty), the
                ``protected_attribute_names`` of the dataset passed to
                :meth:`fit` are used, re-read on every call to :meth:`fit`.
            constraint_weight: When the ``selection_rule`` is
                "tradeoff_optimization" (default, no other option currently)
                this float specifies the relative weight put on the constraint
                violation when selecting the best model. The weight placed on
                the error rate will be ``1-constraint_weight``.
            grid_size (int): The number of Lagrange multipliers to generate in
                the grid.
            grid_limit (float): The largest Lagrange multiplier to generate. The
                grid will contain values distributed between ``-grid_limit`` and
                ``grid_limit`` by default.
            grid (pandas.DataFrame): Instead of supplying a size and limit for
                the grid, users may specify the exact set of Lagrange
                multipliers they desire using this argument in a DataFrame.
            drop_prot_attr (bool): Flag indicating whether to drop protected
                attributes from training data.
            loss (str): String identifying loss function for constraints.
                Options include "ZeroOne", "Square", and "Absolute."
            min_val: Loss function parameter for "Square" and "Absolute,"
                typically the minimum of the range of y values.
            max_val: Loss function parameter for "Square" and "Absolute,"
                typically the maximum of the range of y values.
        """
        super(GridSearchReduction, self).__init__()

        # init model; when no protected attributes were supplied they are
        # filled in from the dataset during fit
        if prot_attr is None:
            prot_attr = []
        # Remember that the caller left prot_attr open, so that each fit()
        # takes it from the dataset at hand instead of reusing the names
        # picked up from whichever dataset happened to be fitted first.
        self._infer_prot_attr = len(prot_attr) == 0
        self.model = skGridSearchRed(prot_attr, estimator, constraints,
                constraint_weight, grid_size, grid_limit, grid, drop_prot_attr,
                loss, min_val, max_val)


    def fit(self, dataset):
        """Learns model with less bias

        The features are wrapped in a DataFrame whose columns are
        ``dataset.feature_names`` and passed, together with
        ``dataset.labels``, to the wrapped model's ``fit``.

        Args:
            dataset : Dataset containing true output.

        Returns:
            GridSearchReduction: Returns self.
        """
        # set prot_attr from the dataset when the caller did not provide one;
        # getattr keeps instances created before this flag existed working
        # (they fall back to the plain "still empty" check).
        if (getattr(self, "_infer_prot_attr", False)
                or len(self.model.prot_attr) == 0):
            self.model.prot_attr = dataset.protected_attribute_names

        X_df = pd.DataFrame(dataset.features, columns=dataset.feature_names)
        Y = dataset.labels

        self.model.fit(X_df, Y)

        return self


    def predict(self, dataset):
        """Obtain the predictions for the provided dataset using the model
        learned.

        The returned dataset is a copy of the input whose ``labels`` are
        replaced by the model's predictions (as a column vector). When the
        model's constraint is a classification moment, ``scores`` is also
        replaced by the predicted probability of the favorable label,
        provided the underlying model supports ``predict_proba``; otherwise
        a warning is logged and ``scores`` is left as copied.

        Args:
            dataset: Dataset containing output values that need to be
                transformed.

        Returns:
            dataset: Transformed dataset.
        """
        X_df = pd.DataFrame(dataset.features, columns=dataset.feature_names)

        dataset_new = dataset.copy()
        dataset_new.labels = self.model.predict(X_df).reshape(-1, 1)

        if isinstance(self.model.moment, red.ClassificationMoment):
            fav_label = dataset.favorable_label

            # The favorable label is a label *value*, not a column position:
            # using it directly as an index is only right when the classes
            # are exactly {0, 1}. Look the value up in the fitted classes
            # when the model exposes them.
            # NOTE(review): presumes ``classes_`` is ordered like the columns
            # of ``predict_proba`` (scikit-learn convention) -- confirm
            # against the wrapped model.
            classes = getattr(self.model, "classes_", None)
            if classes is None:
                # No class information available: keep the historical
                # behaviour of treating the label value as the column index.
                fav_col = int(fav_label)
            else:
                try:
                    fav_col = list(classes).index(fav_label)
                except ValueError:
                    fav_col = None

            if fav_col is None:
                warning("dataset.scores not updated, favorable label is not "
                        "among the classes seen by the model")
            else:
                try:
                    # Probability of favorable label
                    scores = self.model.predict_proba(X_df)[:, fav_col]
                    dataset_new.scores = scores.reshape(-1, 1)
                except (AttributeError, NotImplementedError):
                    warning("dataset.scores not updated, underlying model does not "
                            "support predict_proba")

        return dataset_new
